# -*- coding: utf-8 -*-
# @Time  : 2021/3/20 11:41
# @Author : zhoujiangtao
# @Desc : ==============================================
# Life is Short I Use Python!!!                      
# If this runs wrong,don't ask me,I don't know why;  
# If this runs right,thank god,and I don't know why. 
# Maybe the answer,my friend,is blowing in the wind. 
# ======================================================

from skimage import io
import torch
import matplotlib.pyplot as plt


def show_img(img_t, title="image"):
    """Render a single image in grayscale under a title and show the window.

    :param img_t: 2-D image (tensor or array) accepted by ``plt.imshow``.
    :param title: text placed above the image.
    """
    colormap = "gray"
    plt.imshow(img_t, cmap=colormap)
    plt.title(label=title)
    plt.show()


def get_gray_img(img_f):
    """Load an image file as a grayscale floating-point tensor.

    ``io.imread(..., as_gray=True)`` converts only *color* images (to float64
    in [0, 1]). Files that are already single-channel are returned in their
    stored dtype, e.g. uint8 in 0..255. ``img_as_float`` normalizes that case,
    so callers always receive a floating tensor on the same [0, 1] scale.
    ``binarization`` needs this because it takes ``mean()``, which is
    undefined for integer tensors.

    :param img_f: path of the image file to read.
    :return: 2-D floating tensor built with ``torch.from_numpy``.
    """
    from skimage.util import img_as_float

    gray = img_as_float(io.imread(img_f, as_gray=True))
    return torch.from_numpy(gray)


def show_imgtensor_message(img_t):
    """Print a tensor, its shape, and its global max/min for quick inspection.

    The max/min lines print the ``(values, indices)`` result of reducing the
    tensor flattened to shape ``(1, N)`` along dim 1.

    :param img_t: tensor of any shape and memory layout.
    """
    print(img_t)
    print("shape:{}".format(img_t.shape))
    # reshape instead of view: view raises on non-contiguous tensors
    # (e.g. a transposed image), reshape copies only when it must.
    img_t_flat = img_t.reshape(1, -1)
    print("max:{}".format(img_t_flat.max(dim=1)))
    print("min:{}".format(img_t_flat.min(dim=1)))


def binarization(img_t):
    """Binarize an image using its mean intensity as the threshold.

    Pixels strictly below the mean become 0.0; every other pixel becomes 1.0.
    The templates are sized from the input, so any image shape works. The
    previous version was fixed to 40x80.

    :param img_t: image tensor of any shape.
    :return: float64 tensor with the same shape as ``img_t``.
    """
    # mean() is undefined for integer dtypes, so threshold a float copy.
    src = img_t if img_t.is_floating_point() else img_t.double()
    threshold = src.mean().item()
    z = torch.zeros(img_t.shape, dtype=torch.float64, device=img_t.device)
    o = torch.ones(img_t.shape, dtype=torch.float64, device=img_t.device)
    return torch.where(src < threshold, z, o)


def median_filter(t, kennel=3):
    """Apply a ``kennel`` x ``kennel`` median filter to a 2-D tensor.

    Every pixel that has a full window is replaced by the median of the
    window centred on it (offset ``kennel // 2``). The median is always
    computed from the *unfiltered* input. Pixels too close to the border to
    have a full window are copied unchanged. The previous version had three
    problems: it filtered in place and so read already-filtered values, it
    skipped the last window row/column, and it wrote each result to the
    window's top-left corner.

    :param t: 2-D tensor (rows, cols).
    :param kennel: window side length (spelling kept for keyword callers).
    :return: new tensor with the same shape and dtype; ``t`` is not modified.
    """
    out = t.clone()
    h, w = t.shape[0], t.shape[1]
    if kennel < 1 or h < kennel or w < kennel:
        # No complete window fits: nothing to filter.
        return out
    rows, cols = h - kennel + 1, w - kennel + 1
    # (rows, cols, kennel, kennel): every window, taken from the original input
    windows = t.unfold(0, kennel, 1).unfold(1, kennel, 1)
    medians = windows.reshape(rows, cols, -1).median(dim=-1).values
    pad = kennel // 2
    out[pad:pad + rows, pad:pad + cols] = medians
    return out

# margin: pixels to trim from each slice, given as [top, bottom, left, right]
def spilt(t, label, margin, slice=4):
    """Cut an image into ``slice`` equal-width vertical strips, one per label.

    The image width is divided evenly into ``slice`` strips; any leftover
    columns on the right are dropped. Each strip is then trimmed by
    ``margin``. For 40x80 images with ``slice=4`` this yields exactly the
    20-pixel strips of the previous hard-coded version.

    :param t: 2-D image tensor (rows, cols).
    :param label: sequence with at least ``slice`` items; ``label[i]`` labels
        strip ``i``.
    :param margin: pixels to trim as ``[top, bottom, left, right]``; left and
        right are applied to every strip.
    :param slice: number of strips (name shadows the builtin; kept for
        keyword callers).
    :return: ``(images, labels)`` lists of equal length.
    """
    if slice <= 0:
        return [], []
    top, bottom, left, right = margin[0], margin[1], margin[2], margin[3]
    height = t.shape[0]
    strip_w = t.shape[1] // slice
    ls_image = []
    ls_lab = []
    for i in range(slice):
        x0 = strip_w * i
        ls_image.append(t[top:height - bottom, x0 + left:x0 + strip_w - right])
        ls_lab.append(label[i])
    return ls_image, ls_lab

def show_images(images,labels):
    """Show several images side by side in a single figure.

    :param images: sequence of 2-D images (tensors or arrays).
    :param labels: figure title, passed unchanged to ``plt.suptitle``
        (e.g. the label list returned by ``spilt``).
    """
    plt.figure()
    # The title belongs to the whole figure, so set it once rather than
    # once per subplot.
    plt.suptitle(labels)
    count = len(images)
    for idx, img in enumerate(images, start=1):
        plt.subplot(1, count, idx)
        # Grayscale, consistent with show_img; matplotlib's default colormap
        # (viridis) would render binary images in purple/yellow.
        plt.imshow(img, cmap="gray")
    plt.show()

import os

def preprocess_images(image_path = "./data/train/",save2 = "./data/image_train/"):
    """Turn every captcha in ``image_path`` into per-character training images.

    Each file ``<name>.<ext>`` is processed in four steps:
    1. Load it as grayscale.
    2. Binarize it.
    3. Median-filter it.
    4. Split it into one slice per character of ``<name>``.

    The slice for the character at 1-based position ``pos`` is written to
    ``<save2>/<char>/<name>_<pos>.png``; for example, the "a" of ``2a45.png``
    becomes ``<save2>/a/2a45_2.png``. Using the position (not the character)
    keeps names unique when a captcha repeats a character. Directories and
    hidden entries such as ``.DS_Store`` in ``image_path`` are skipped.
    Output directories are created as needed.

    :param image_path: directory containing the raw captcha images.
    :param save2: root output directory; one subdirectory per character.
    """
    for f in sorted(os.listdir(image_path)):
        src = os.path.join(image_path, f)
        if f.startswith(".") or not os.path.isfile(src):
            continue
        # "2a45.png" -> "2a45" (the old replace("png", "") left a trailing dot)
        image_name = os.path.splitext(f)[0]
        img_t_b = binarization(get_gray_img(src))
        img_t_m = median_filter(img_t_b)
        images, labels = spilt(img_t_m, image_name, margin=[5, 5, 1, 1])
        for pos, (image, label) in enumerate(zip(images, labels), start=1):
            lab_dir = os.path.join(save2, label)
            os.makedirs(lab_dir, exist_ok=True)
            # Saving needs pixel values: scale the 0/1 gray values by 255.
            pixels = (image * 255).to(torch.uint8)
            image_f = os.path.join(lab_dir, "{}_{}.png".format(image_name, pos))
            io.imsave(image_f, pixels.numpy())


if __name__ == "__main__":
    # Default run: turn the raw training captchas into per-character images.
    preprocess_images()

    # Same pipeline for the test split:
    # preprocess_images(image_path="./data/test/", save2="./data/image_test/")

    # Inspect each preprocessing stage on a single captcha:
    # img_t = get_gray_img("./data/new/2a45.png")
    # show_img(img_t)
    # show_imgtensor_message(img_t)
    # img_t_b = binarization(img_t)
    # show_img(img_t_b)
    # img_t_m = median_filter(img_t_b)
    # show_img(img_t_m)
    # images, labels = spilt(img_t_m, "2a45", margin=[5, 5, 1, 1])
    # show_images(images, labels)